"""
自动获取包含FAST角点的手势切片，用于生成CNN图像数据集
"""
import os
import cv2
import random
import config
import CKA
from skimage import exposure
import numpy as np
import gesture_recognition_utility as gu


# Detection and clustering parameters, taken from the project-wide config module.
RESCALE_RATE, KERNEL_SIZE, FAST_RADIUS, SET_MAX_LENGTH = (
    config.RESCALE_RATE,    # factor used to shrink frames before corner detection
    config.KERNEL_SIZE,     # passed to CKA.weighted_cluster as `ks`
    config.FAST_RADIUS,     # passed to gu.get_fast_corner
    config.SET_MAX_LENGTH,  # passed to CKA.weighted_cluster as `max_length`
)

# Every input frame/image is resized to this size before processing.
REVISE_WIDTH, REVISE_HEIGHT = config.REVISE_WIDTH, config.REVISE_HEIGHT

# Size of the patches cut around each clustered point and written to disk.
MODEL_WIDTH = MODEL_HEIGHT = 64

# Directory whose entries are processed one by one.
VIDEO_FILE_DIR = './DataSets/Inputs/'

# True: treat every input as a still image (cv2.imread);
# False: treat every input as a video (cv2.VideoCapture).
IS_DEX_TEST = True


OUTPUT_DIR = './DataSets/Outputs/'


def crop_patch(image, point, rescale_rate, patch_rows, patch_cols):
    """Cut a fixed-size patch of ``image`` centred on a detected point.

    Args:
        image: NumPy image indexed as ``image[row, col, ...]``.
        point: ``(x, y)`` coordinate expressed in the downscaled detection image.
        rescale_rate: factor that maps detection coordinates back onto ``image``.
        patch_rows: number of rows of the patch.
        patch_cols: number of columns of the patch.

    Returns:
        An independent copy of the patch, or ``None`` when the window does not
        lie completely inside ``image``.
    """
    centre_row = int(point[1] * rescale_rate)
    centre_col = int(point[0] * rescale_rate)
    top = centre_row - int(patch_rows / 2)
    left = centre_col - int(patch_cols / 2)
    # A negative start would silently wrap around in NumPy slicing, so reject
    # it explicitly instead of relying on the shape check below.
    if top < 0 or left < 0:
        return None
    patch = image[top:top + patch_rows, left:left + patch_cols]
    if patch.shape[0] != patch_rows or patch.shape[1] != patch_cols:
        return None
    # Return a copy so later in-place processing cannot alter the source image.
    return patch.copy()


def find_gesture_points(image):
    """Detect FAST corners on a downscaled copy of ``image`` and cluster them.

    Returns the point set produced by ``CKA.weighted_cluster``; coordinates are
    in the downscaled space and must be multiplied by ``RESCALE_RATE`` to map
    them back onto ``image``.
    """
    small = cv2.resize(
        image,
        (int(REVISE_WIDTH / RESCALE_RATE), int(REVISE_HEIGHT / RESCALE_RATE)),
    )
    keypoints = gu.get_fast_corner(small, FAST_RADIUS)
    corners = [[int(x), int(y)] for x, y in cv2.KeyPoint_convert(keypoints)]
    # NOTE(review): the background is (REVISE_WIDTH, REVISE_HEIGHT) although the
    # detection image is downscaled and NumPy shapes are (rows, cols). Kept as
    # in the original because CKA.weighted_cluster's contract is not visible
    # here -- confirm against CKA.
    background = np.zeros((REVISE_WIDTH, REVISE_HEIGHT))
    _, clustered = CKA.weighted_cluster(
        background, corners, ks=KERNEL_SIZE, max_length=SET_MAX_LENGTH)
    return clustered


def extract_patches(image):
    """Yield ``(point_index, patch)`` for every clustered point of ``image``.

    ``image`` is first resized to ``(REVISE_WIDTH, REVISE_HEIGHT)``. Points whose
    patch would cross the image border are skipped, but they still consume an
    index so that indices match positions in the clustered point set.
    """
    full = cv2.resize(image, (REVISE_WIDTH, REVISE_HEIGHT))
    for index, point in enumerate(find_gesture_points(full)):
        patch = crop_patch(full, point, RESCALE_RATE, MODEL_HEIGHT, MODEL_WIDTH)
        if patch is not None:
            yield index, patch


def augment_patch(patch):
    """Apply a random transpose, flip and gamma change to ``patch``."""
    flip_code = random.randint(-1, 1)               # random flip axis
    gamma = round(random.uniform(0.6, 1.2), 1)      # random brightness
    if random.randint(0, 1) == 0:                   # random transpose
        patch = cv2.transpose(patch)
    patch = cv2.flip(patch, flip_code)
    return exposure.adjust_gamma(patch, gamma)


def process_video(video_path, file_index):
    """Write augmented patches for every second frame of the video."""
    capture = cv2.VideoCapture(video_path)
    try:
        frame_count = 0
        while True:
            ret, frame = capture.read()
            if not ret or frame is None:    # end of stream or unreadable file
                break
            frame_count += 1
            if frame_count % 2 != 0:        # only every second frame is used
                continue
            if frame.shape[1] > frame.shape[0]:     # force portrait orientation
                frame = cv2.flip(cv2.transpose(frame), 1)
            for index, patch in extract_patches(frame):
                patch = augment_patch(gu.get_corner_and_area(patch))
                # Separators keep (file, point, frame) index triples unambiguous.
                name = '{}_{}_{}.bmp'.format(file_index, index, frame_count)
                cv2.imwrite(os.path.join(OUTPUT_DIR, name), patch)
            # Kept from the original; its key check only guarded a no-op `continue`.
            cv2.waitKey(1)
    finally:
        capture.release()


def process_image(image_path, file_index):
    """Write the patches found in a single still image."""
    image = cv2.imread(image_path)
    if image is None:   # cv2.imread returns None instead of raising
        print('Skipping unreadable image: ' + image_path)
        return
    for index, patch in extract_patches(image):
        patch = gu.get_corner_and_area(patch, True)
        name = '{}_{}.bmp'.format(file_index, index)
        cv2.imwrite(os.path.join(OUTPUT_DIR, name), patch)


if __name__ == '__main__':
    file_list = os.listdir(VIDEO_FILE_DIR)
    print(file_list)
    # cv2.imwrite fails silently when the target directory does not exist.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Walk over every input file.
    for file_index, file_name in enumerate(file_list):
        input_path = VIDEO_FILE_DIR + str(file_name)
        print(input_path)
        if IS_DEX_TEST:
            process_image(input_path, file_index)
        else:
            process_video(input_path, file_index)
        # Kept from the original; its key check only guarded a no-op `continue`.
        cv2.waitKey(1)
    cv2.destroyAllWindows()
